# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========

import asyncio
import random
from datetime import datetime
from typing import Iterable, List, Optional, cast

from pydantic import BaseModel, Field, ValidationError

from camel.agents import ChatAgent
from camel.logger import get_logger
from camel.models import ModelFactory
from camel.types import ModelPlatformType, ModelType
from camel.verifiers import BaseVerifier

from .base_generator import BaseGenerator
from .models import DataPoint
from .static_dataset import StaticDataset

logger = get_logger(__name__)

# System prompt for the default instruction-generation agent. It asks the
# model to synthesize ONE new, self-contained question that has a single
# clear, objective, verifiable answer. The question is built from several
# few-shot examples, and the reply must use a strict "Question: ..." format.
# NOTE(review): "in structure, and complexity" reads as if a list item was
# dropped. It is left as-is because editing this text changes what the
# model sees.
DEFAULT_INSTRUCTION_SYSTEM_PROMPT = """
You are a high-capacity instruction generation assistant.

Your task is to generate a **new, creative, and challenging question** based on
several examples.
These examples may cover different domains or styles, but your goal is to:
- **Understand their specific patterns** in structure, and complexity;
- **Combine and synthesize** ideas from multiple examples, rather than copying
  or lightly editing any single one;
- **Intelligently integrate** multiple reasoning steps, constraints, or
  concepts into a single, coherent question;
- Ensure the new question is **non-trivial** and requires deep thinking or
  multi-step reasoning.

**Guidelines:**
- Use the examples as inspiration for format, depth, and tone.
- Your new question should be self-contained, logically sound, and answerable.
- Do not repeat exact phrasings or create shallow combinations; instead,
  produce something meaningfully new.
- Avoid open-ended or subjective questions that depend on personal opinions or
  discussion.
- The generated question must have a **clear, objective, and verifiable
  answer**.
- Aim for increased depth or novelty through subtle combination or
  transformation.
- Keep the final output to a **single unified question** with one clear answer,
  not a multi-part task.

**Output Format (strict):**
```
Question: [Generated question]
```
"""

# System prompt template for the default rationale (code-solution) agent.
# It demands code-only answers that compute the result rather than hardcode
# it, and restricts the model to the libraries substituted for
# ``{package_list}``.
# The template is rendered with ``str.format(package_list=...)``, so any
# literal brace added to this text must be doubled (``{{`` / ``}}``).
# NOTE(review): this text asks for code wrapped in triple backticks, but the
# structured ``code`` field the agent fills (RationaleSchema) is described
# as "without any formatting". Confirm which form the verifier expects.
DEFAULT_RATIONALE_SYSTEM_PROMPT = """You are an advanced Python code assistant.

Your task is to **solve the given question by writing Python code only**,
without any explanation or natural language output.
The code must compute the answer **programmatically**, not by hardcoding or
guessing the result.

**Rules:**
- Use Python code to perform the actual computation.
- Use {package_list} to solve the problem. Do not import any other libraries.
- **Do not hardcode the final answer** (e.g., avoid writing `print(1/2)` unless
  that value is computed).
- The result must be obtained through valid computation logic in code.
- Do not include explanations. Output code only.
- The entire code must be wrapped in triple backticks:
```
[Your Python code here]
```

Now, solve the following question using Python. Only output the code:
"""


class SelfInstructGenerator(BaseGenerator):
    r"""A generator for creating synthetic datapoints using self-instruct.

    It utilizes both a human-provided dataset (seed_dataset) and generated
    machine instructions (machine_instructions) to produce new, synthetic
    datapoints that include a question, a computed rationale (code), and a
    final answer (from a verifier).

    Every datapoint that passes verification is appended to
    :attr:`machine_instructions`. Later generation attempts can therefore
    sample previously generated questions as few-shot examples alongside
    the human seed questions.
    """

    class QuestionSchema(BaseModel):
        r"""Schema for the generated question.

        Attributes:
            question (str): The question generated by the model.
        """

        question: str = Field(description="The question generated")

    class RationaleSchema(BaseModel):
        r"""Schema for the generated rationale code.

        Attributes:
            code (str): The generated code without any formatting.
        """

        code: str = Field(
            description="The generated code without any formatting"
        )

    def __init__(
        self,
        seed_dataset: StaticDataset,
        verifier: BaseVerifier,
        instruction_agent: Optional[ChatAgent] = None,
        rationale_agent: Optional[ChatAgent] = None,
        seed: int = 42,
        **kwargs,
    ):
        r"""Initialize the self-instruct generator.

        Args:
            seed_dataset (StaticDataset): Dataset containing seed instructions.
            verifier (BaseVerifier): Verifier instance to validate generated
                solutions.
            instruction_agent (Optional[ChatAgent]): Agent for generating
                instructions. If not provided, a default agent will be created.
            rationale_agent (Optional[ChatAgent]): Agent for generating
                rationales. If not provided, a default agent will be created.
            seed (int): Random seed for reproducibility. (default: :obj:`42`)
            **kwargs: Additional keyword arguments passed to the BaseGenerator.
        """
        super().__init__(seed=seed, **kwargs)
        self.seed_dataset = seed_dataset
        self.verifier = verifier
        # Packages the verifier declares it needs. They are also the only
        # libraries the default rationale agent is told to use. A missing or
        # ``None`` attribute falls back to an empty list.
        self.packages: List[str] = (
            getattr(self.verifier, "required_packages", None) or []
        )
        # Must run after self.packages is set: the default rationale agent
        # embeds the package list in its system prompt.
        self.instruction_agent = (
            instruction_agent or self.default_instruction_agent()
        )
        self.rationale_agent = (
            rationale_agent or self.default_rationale_agent()
        )

        # Extract questions from the seed dataset as human_instructions
        self.human_instructions: List[str] = [
            dp.question for dp in cast(Iterable[DataPoint], self.seed_dataset)
        ]
        # Verified synthetic datapoints, reused as few-shot examples.
        self.machine_instructions: List[DataPoint] = []
        # Create an instance-level lock for thread-safe updates to _data
        self._lock = asyncio.Lock()
        self._data = []  # Storage for generated DataPoint instances

    def default_instruction_agent(self) -> ChatAgent:
        r"""Create the default instruction generation agent.

        This agent is configured with a moderate temperature setting to
        encourage creative and diverse instruction generation behavior.

        Returns:
            ChatAgent: An agent with the default instruction prompt.
        """
        model = ModelFactory.create(
            model_platform=ModelPlatformType.DEFAULT,
            model_type=ModelType.DEFAULT,
            model_config_dict={"temperature": 0.7},
        )
        return ChatAgent(
            DEFAULT_INSTRUCTION_SYSTEM_PROMPT,
            model=model,
        )

    def default_rationale_agent(self) -> ChatAgent:
        r"""Create the default rationale generation agent.

        This agent is configured with a deterministic (zero temperature)
        setting to ensure consistent and precise rationale generation based on
        a given instruction and package list.

        The package list is rendered as a comma-separated string, e.g.
        ``"numpy, sympy"``. When the verifier declares no packages, the
        prompt restricts the model to the Python standard library instead of
        printing a raw list repr such as ``[]``.

        Returns:
            ChatAgent: An agent with the rationale prompt
        """
        package_list = (
            ", ".join(self.packages)
            if self.packages
            else "only the Python standard library"
        )
        model = ModelFactory.create(
            model_platform=ModelPlatformType.DEFAULT,
            model_type=ModelType.DEFAULT,
            model_config_dict={"temperature": 0.0},
        )
        return ChatAgent(
            DEFAULT_RATIONALE_SYSTEM_PROMPT.format(package_list=package_list),
            model=model,
        )

    @staticmethod
    def format_support_block(dp: DataPoint) -> str:
        r"""Format a DataPoint into a few-shot example block.

        Args:
            dp (DataPoint): A data point.

        Returns:
            str: A formatted string containing the question and its
                corresponding code block in Markdown-style Python format.
        """
        support_q = dp.question.strip()
        support_code = dp.rationale.strip() if dp.rationale else ""
        return (
            f"Question:\n{support_q}\n\n"
            "Code:\n"
            "```python\n"
            f"{support_code}\n"
            "```"
        )

    @staticmethod
    def _parsed_output(response, what: str):
        r"""Return the structured output of an agent response.

        Args:
            response: The value returned by ``ChatAgent.step``.
            what (str): Human-readable name of the expected output, used in
                the error message.

        Returns:
            The ``parsed`` attribute of the first message in
            ``response.msgs``.

        Raises:
            ValueError: If the response has no messages or the first
                message carries no parsed structured output.
        """
        msgs = response.msgs
        if not msgs or msgs[0].parsed is None:
            raise ValueError(
                f"Model returned no parsable {what}; messages: {msgs}"
            )
        return msgs[0].parsed

    @staticmethod
    def _sample_support(
        population: List[DataPoint], k: int
    ) -> List[DataPoint]:
        r"""Sample up to ``k`` distinct items from ``population``.

        Args:
            population (List[DataPoint]): Items to sample from.
            k (int): Requested sample size. Values ``<= 0`` yield no items.
                Values larger than the population are capped at its size.

        Returns:
            List[DataPoint]: The sampled items (possibly empty).
        """
        if k <= 0 or not population:
            return []
        return random.sample(population, min(k, len(population)))

    def generate_new_instruction(
        self,
        agent: ChatAgent,
        support_human_dps: list[DataPoint],
        support_machine_dps: list[DataPoint],
    ) -> str:
        r"""Generate a new instruction using self-instruct prompting.

        Args:
            agent (ChatAgent): The agent to use for generating the instruction.
            support_human_dps (list[DataPoint]): List of human examples to
                sample.
            support_machine_dps (list[DataPoint]): List of machine examples to
                sample.

        Returns:
            str: The newly generated question.

        Raises:
            ValueError: If the agent returns no parsable question.
        """
        few_shot_examples = [dp.question for dp in support_human_dps] + [
            dp.question for dp in support_machine_dps
        ]

        # Build the prompt: numbered examples, then an empty numbered slot
        # for the new question, then the instruction.
        parts = ["Below are some question examples:\n\n"]
        parts.extend(
            f"Question {idx}: {instr}\n"
            for idx, instr in enumerate(few_shot_examples, start=1)
        )
        parts.append(f"Question {len(few_shot_examples) + 1}:\n")
        parts.append(
            "Now generate a new question based on the given examples.\n"
        )
        prompt = "".join(parts)

        question_template = f"Question: {prompt}"
        # NOTE(review): the same agent instance is reused across attempts in
        # generate_new. If ChatAgent keeps conversation memory between
        # step() calls, earlier prompts accumulate in its context. Confirm
        # whether a reset per attempt is intended.
        response = cast(
            SelfInstructGenerator.QuestionSchema,
            self._parsed_output(
                agent.step(
                    question_template, response_format=self.QuestionSchema
                ),
                "question",
            ),
        )
        return response.question

    def generate_rationale(
        self,
        question: str,
        agent: Optional[ChatAgent] = None,
        support_human_dps: Optional[list[DataPoint]] = None,
    ) -> str:
        r"""Generate rationale code (solution) for the given question.

        Args:
            question (str): The question to be solved.
            agent (Optional[ChatAgent]): The agent to use for generating the
                rationale. If None is provided, a freshly created default
                rationale agent will be used. (default: :obj:`None`)
            support_human_dps (Optional[list[DataPoint]]): List of human
                examples to sample. (default: :obj:`None`)

        Returns:
            str: The generated code solution as a string.

        Raises:
            ValueError: If the agent returns no parsable rationale.
        """
        sections: List[str] = []
        if support_human_dps:
            examples = "\n\n".join(
                self.format_support_block(dp) for dp in support_human_dps
            )
            sections.append(
                "Below are example questions and solutions:\n\n" + examples
            )
        sections.append(f"Write code to solve the question:\n{question}")
        # Only separate sections that exist, so a prompt without examples
        # does not start with blank lines.
        few_shot_prompt = "\n\n".join(sections)

        rationale_agent = agent or self.default_rationale_agent()
        response = cast(
            SelfInstructGenerator.RationaleSchema,
            self._parsed_output(
                rationale_agent.step(
                    few_shot_prompt, response_format=self.RationaleSchema
                ),
                "rationale",
            ),
        )
        return response.code

    async def generate_new(
        self,
        n: int,
        max_retries: int = 10,
        human_sample_count: int = 3,
        machine_sample_count: int = 1,
        **kwargs,
    ) -> None:
        r"""Generates and validates `n` new datapoints through
        self-instruct prompting, with a retry limit.

        Each attempt does the following:

        - samples few-shot examples from the seed dataset and from
          :attr:`machine_instructions`;
        - asks the instruction agent for a new question;
        - asks the rationale agent for solution code;
        - runs that code through the verifier.

        The verifier's result becomes the datapoint's ``final_answer``.

        Args:
            n (int): The number of valid datapoints to generate.
            max_retries (int): Maximum number of failed attempts before
                stopping. (default: :obj:`10`)
            human_sample_count (int): Number of human examples to sample.
                Values ``<= 0`` sample none. (default: :obj:`3`)
            machine_sample_count (int): Number of machine examples to sample.
                Values ``<= 0`` sample none. (default: :obj:`1`)
            **kwargs: Additional keyword arguments (currently unused).

        Raises:
            RuntimeError: If retries are exhausted before `n` valid
                datapoints exist. In that case nothing is added to the
                generator's stored data.

        Notes:
            - Only failed attempts count against `max_retries`.
            - Each verified datapoint is appended to
              :attr:`machine_instructions` immediately. Later attempts (in
              this call and in future calls) can then use it as a few-shot
              example.
            - Metadata includes a timestamp for tracking datapoint creation.
        """
        valid_data_points: list[DataPoint] = []
        retries = 0
        # The seed dataset is treated as fixed during generation, so it is
        # materialized once rather than on every attempt.
        human_dps_list = list(cast(Iterable[DataPoint], self.seed_dataset))

        while len(valid_data_points) < n and retries < max_retries:
            try:
                support_human_dps = self._sample_support(
                    human_dps_list, human_sample_count
                )
                support_machine_dps = self._sample_support(
                    self.machine_instructions, machine_sample_count
                )
                # NOTE(review): the two generation calls below use the
                # synchronous agent.step from inside this coroutine. They
                # therefore block the event loop for the duration of each
                # model request.
                question = self.generate_new_instruction(
                    self.instruction_agent,
                    support_human_dps,
                    support_machine_dps,
                )
                rationale = self.generate_rationale(
                    question, self.rationale_agent, support_human_dps
                )
                if not isinstance(rationale, str):
                    raise TypeError(f"Rationale {rationale} is not a string.")

                try:
                    verifier_response = await self.verifier.verify(
                        solution=rationale,
                        reference_answer=None,
                    )
                    # An empty/missing result is treated as a failed run.
                    if not verifier_response or not verifier_response.result:
                        raise ValueError(
                            "Verifier unsuccessful, response: "
                            f"{verifier_response}"
                        )
                except (ValueError, AttributeError) as e:
                    logger.warning(
                        f"Verifier issue: {e}, "
                        f"retrying... ({retries + 1}/{max_retries})"
                    )
                    retries += 1
                    continue
                try:
                    new_datapoint = DataPoint(
                        question=question,
                        rationale=rationale,
                        final_answer=verifier_response.result,
                        metadata={
                            "synthetic": str(True),
                            "created": datetime.now().isoformat(),
                            "generator": "self_instruct",
                        },
                    )
                except ValidationError as e:
                    logger.warning(
                        f"Datapoint validation failed: {e}, "
                        f"retrying... ({retries + 1}/{max_retries})"
                    )
                    retries += 1
                    continue

                valid_data_points.append(new_datapoint)
                # Feed the verified datapoint back as a machine example so
                # subsequent attempts can sample it (the self-instruct loop).
                self.machine_instructions.append(new_datapoint)

            except Exception as e:
                logger.warning(
                    f"Unexpected error: {e}, retrying..."
                    f" ({retries + 1}/{max_retries})"
                )
                retries += 1

        if len(valid_data_points) < n:
            raise RuntimeError(
                f"Failed to generate {n} valid datapoints "
                f"after {max_retries} retries."
            )

        async with self._lock:
            self._data.extend(valid_data_points)